import asyncio

from scipy import stats

from devicepilot.pylog.pylogger import PyLogger
from devicepilot.export import pli
from devicepilot.config_enum import hal_enum
from devicepilot.common.helper import get_iso_date

from config_enum import eef_measurement_unit_enum as meas_enum
from config_enum import od_filter_wheel_enum as od_enum
from config_enum import detector_aperture_slider_enum as das_enum

from hw_abstraction.hal import HAL

from urpc.eefmeasurementparameter import EEFMeasurementParameter

# from predefined_tasks.common.list_calculations import plot_2d_list
# from predefined_tasks.common.helper import send_to_gc


# Hardware abstraction layer instance, taken from the pli export module.
hal: HAL = pli.hal
# Module-level shortcuts to the HAL components used by the code below.
meas_unit = hal.measurement_unit
das1 = hal.detector_aperture_slider1
od1 = hal.od_filter_wheel1
od2 = hal.od_filter_wheel2


async def fi_linearization(pmt_serial_no: str, hv=0.35, ahrs=101.0, ace=1.0):
    """Measure PMT1 over all OD filter wheel combinations and write a correction file.

    The hardware is started, initialized and homed, PMT1 is configured, and one
    FI measurement is taken for every pair of filter wheel positions. The
    collected signals are evaluated with ``FILinearization`` and the resulting
    correction string is written via ``write_corrected_values_to_file``.
    The hardware is shut down in every case, also when an error occurs.

    Args:
        pmt_serial_no: Serial number used in the name of the correction file.
        hv: Value written to ``PMT1.HighVoltageSettingFI``.
        ahrs: Value written to ``PMT1.AnalogHighRangeScale``.
        ace: Value written to ``PMT1.AnalogCountingEquivalent``.

    Returns:
        The string ``'fi_linearization done'``.
    """
    exc_ms = 100
    low_power = True
    # Every wheel combination: OD1 factor 7 down to 1 and, for each of them,
    # OD2 factor 8 down to 1 (56 pairs, starting with (7, 8), ending with (1, 1)).
    position_factors = tuple(
        (od_1_factor, od_2_factor)
        for od_1_factor in range(7, 0, -1)
        for od_2_factor in range(8, 0, -1)
    )

    try:
        await hal.startup_hardware()
        await hal.initialize_device()
        await hal.home_movers()

        await meas_unit.set_config(meas_enum.PMT1.HighVoltageSettingFI, hv)
        await meas_unit.set_config(meas_enum.PMT1.AnalogHighRangeScale, ahrs)
        await meas_unit.set_config(meas_enum.PMT1.AnalogCountingEquivalent, ace)
        await meas_unit.enable_pmt_hv_fi()
        await meas_unit.enable_flash_lamp_power(low_power)

        meas_unit.clear_measurements()
        guid = "6131ae47-f0c2-4d12-bfa7-b24fa13c4ac1"
        await meas_unit.load_fi_measurement(guid, measurement_time=exc_ms)

        PyLogger.logger.info("FI Operation loaded without setting optical path")

        # The log text accumulates: first the PMT1 high voltage parameter,
        # then one line per measured signal.
        refs = ""
        pmt_high_voltage = await meas_unit.endpoint.get_parameter(
            EEFMeasurementParameter.PMT1HighVoltageSetting
        )
        high_voltage_text = str(pmt_high_voltage)
        # await send_to_gc(high_voltage_text)
        refs += high_voltage_text + "\n"
        PyLogger.logger.info(refs)

        await das1.move_to_named_position(das_enum.Positions.Aperture30)

        # NOTE(review): get_config is called without await here, unlike
        # meas_unit.set_config above - confirm that it is synchronous.
        od_1_offset = od1.get_config(od_enum.Positions.Offset)
        od_2_offset = od2.get_config(od_enum.Positions.Offset)
        data = []
        lin = None
        try:
            for od_1_factor, od_2_factor in position_factors:
                # Factor 1 maps to the offset itself, each further factor adds 45.
                od_1_target = od_1_offset + (od_1_factor - 1) * 45
                od_2_target = od_2_offset + (od_2_factor - 1) * 45

                # Both wheels are moved concurrently.
                await asyncio.gather(od1.move(od_1_target), od2.move(od_2_target))

                await meas_unit.execute_measurement(guid)
                fi_result = await meas_unit.read_fi_measurement_results(
                    guid, measurement_time=exc_ms, iref0=100000
                )
                signal = fi_result._pmt1_analog_total
                signal_text = str(signal)
                refs += signal_text + "\n"
                # await send_to_gc(signal_text)
                PyLogger.logger.info(refs)
                data.append(signal)

            lin = FILinearization(data)
            lin.find_plateau()
            lin.calc_linear_regression()
            lin.calc_linear_values()
            lin.compare_plateau_measured_values()
            # lin.to_csv("test.csv")
            correction_factors = lin.create_correction_factors()
            correction_text = lin.create_correction_values_str(correction_factors)
            # await send_to_gc(f"\nf(x)={lin.slope} x + {lin.intercept}")
            write_corrected_values_to_file(str(pmt_serial_no), correction_text)
        except Exception:
            # await send_to_gc("Error in fi-lin:" + exception_to_string(e), log=True, error=True)
            raise
        finally:
            if lin and len(lin.plateau_values):
                # Only consumed by the plot call below, which is currently disabled.
                plot_list = [[index, value] for index, value in enumerate(lin.plateau_values)]
                # await plot_2d_list(plot_list, f"FI Linearization of {pmt_serial_no}", "", "")
    finally:
        await hal.shutdown_hardware()

    return 'fi_linearization done'


class FILinearization(object):
    """Derive linearization correction factors from a series of measured values.

    Each measured value is paired with the entry of ``TRANSMISSION_VALUES`` at
    the same position. The "Column X" comments refer to the columns written by
    ``to_csv``.

    Intended call order (each step uses results of the previous ones):
    ``find_plateau`` -> ``calc_linear_regression`` -> ``calc_linear_values`` ->
    ``compare_plateau_measured_values`` -> ``create_correction_factors`` ->
    ``create_correction_values_str``.
    """

    # Column A
    # NOTE(review): an earlier, commented-out revision of this table had 57
    # entries; it carried an additional 0.284282418 between 0.208929613 and
    # 0.351059784. Confirm that the 56 values below are the intended ones.
    TRANSMISSION_VALUES = [
        4.07989E-08, 6.35331E-08, 1.06753E-07, 1.29128E-07, 1.76490E-07, 2.16553E-07, 2.62276E-07, 3.04089E-07,
        7.72056E-07, 1.20226E-06, 2.02014E-06, 2.44354E-06, 3.33979E-06, 4.09793E-06, 4.96316E-06, 5.75440E-06,
        9.28215E-06, 1.44544E-05, 2.42874E-05, 2.93779E-05, 4.01531E-05, 4.92680E-05, 5.96703E-05, 6.91831E-05,
        6.72432E-05, 0.000104713, 0.000175947, 0.000212824, 0.000290883, 0.000356915, 0.000432273, 0.000501187,
        0.001084143, 0.001688254, 0.002836736, 0.003431295, 0.004689826, 0.005754435, 0.006969409, 0.008080493,
        0.015201785, 0.023672599, 0.039776542, 0.048113419, 0.065760463, 0.080688346, 0.097724634, 0.113304184,
        0.134167908, 0.208929613, 0.351059784, 0.424639383, 0.580388661, 0.712139155, 0.862498023, 1.000000000,
    ]

    def __init__(self, measured_values):
        """Pair the measured values with the transmission table.

        Args:
            measured_values: Sequence of numbers; the value at position ``i``
                belongs to ``TRANSMISSION_VALUES[i]``.

        Raises:
            IndexError: If more values are passed than transmission values exist.
        """
        self.amount_values = len(measured_values)
        if self.amount_values > len(self.TRANSMISSION_VALUES):
            raise IndexError(
                f"got {self.amount_values} measured values but only "
                f"{len(self.TRANSMISSION_VALUES)} transmission values are defined"
            )
        # Column B
        # NOTE(review): the reason for scaling the measured values by 1/100 is
        # not visible in this file - confirm against the measurement setup.
        self.trans_measured_values = [
            [transmission_value, measured_value / 100]
            for transmission_value, measured_value in zip(self.TRANSMISSION_VALUES, measured_values)
        ]
        # The table is not strictly ascending, so sort by transmission value.
        self.trans_measured_values.sort(key=lambda pair: pair[0])

        self.plateau_values = None
        self.best_plateau = None   # List with index, calc_value, measured_value
        self.linear_values = None

        # Values for linear function
        self.slope = None
        self.intercept = None

        self.compare_values = []

    @staticmethod
    def _deviation_percent(reference, value):
        """Return the deviation in percent of ``reference`` relative to ``value``.

        A ``value`` of zero cannot be compared and is reported as 100 percent.
        """
        if value == 0.0:
            return 100
        return abs(100 - (100 * reference) / value)

    def find_plateau(self):
        """Find the plateau of measured/transmission ratios and store it.

        First the longest continuous run is searched in which every ratio
        deviates less than 5 percent from the average of the run so far. Then
        all other values within 5 percent of that run's average are added, and
        indices lying between the first and last plateau index are filled in
        (with a warning) so that the stored plateau is contiguous.

        Sets ``plateau_values`` (Column F, ratio for every index) and
        ``best_plateau`` (tuples of index, ratio, measured value, sorted by index).

        Raises:
            ValueError: If no value qualifies for a plateau (e.g. no data or
                only zero values).
        """
        # Column F
        self.plateau_values = []
        last_plateau = []
        self.best_plateau = []
        last_values = None  # (ratio, measured value) of the previous index
        maximum_diff_perc = 5
        # Search for longest continuous plateau of values with +-5% difference
        for i in range(self.amount_values):
            transmission_value, measured_value = self.trans_measured_values[i]
            current_value = measured_value / transmission_value
            if last_plateau:
                current_average = self.calc_plateau_average(last_plateau)
            else:
                # Nothing to compare against yet: the value starts a new run.
                current_average = current_value
            self.plateau_values.append(current_value)
            diff = self._deviation_percent(current_average, current_value)
            in_range = diff < maximum_diff_perc
            if in_range:
                # A new run also takes in its predecessor (which ended the
                # previous run). Index 0 has no predecessor, so nothing is
                # prepended there; otherwise a bogus entry would distort the
                # run average.
                if not last_plateau and i > 0:
                    last_plateau.append((i - 1, last_values[0], last_values[1]))
                last_plateau.append((i, current_value, measured_value))
                if len(last_plateau) > len(self.best_plateau):
                    # Copy so that the stored result never aliases the working list.
                    self.best_plateau = list(last_plateau)
            else:
                last_plateau = []
            last_values = current_value, measured_value
            print(i, current_value, diff, in_range)

        if not self.best_plateau:
            raise ValueError("no plateau found in the measured values")

        # Get average of that plateau
        plateau_average = self.calc_plateau_average(self.best_plateau)

        # Search for other non-continuous values with +-5%
        plateau_indices = [entry[0] for entry in self.best_plateau]
        new_plateau = self.get_best_values(plateau_indices, plateau_average, maximum_diff_perc)

        # Fill the gaps between first and last plateau index.
        indices = [entry[0] for entry in new_plateau]
        known_indices = set(indices)
        for i in range(indices[0], indices[-1]):
            if i not in known_indices:
                PyLogger.logger.warning(f"missing index: {i}")
                transmission_value, measured_value = self.trans_measured_values[i]
                current_value = measured_value / transmission_value
                new_plateau.append((i, current_value, measured_value))

        new_plateau.sort(key=lambda entry: entry[0])
        self.best_plateau = new_plateau

    @staticmethod
    def calc_plateau_average(values):
        """Return the mean of the second element of every 3-tuple in ``values``.

        Raises:
            ZeroDivisionError: If ``values`` is empty.
        """
        return sum(current_value for _, current_value, __ in values) / len(values)

    def get_best_values(self, plateau_indices, average, maximum_diff_perc):
        """Collect the plateau entries plus all values close to ``average``.

        Args:
            plateau_indices: Indices that belong to the plateau in any case.
                The list is extended in place (and kept sorted) with every
                additionally accepted index.
            average: Reference ratio the other values are compared with.
            maximum_diff_perc: Exclusive upper limit of the deviation in percent.

        Returns:
            List of (index, ratio, measured value) tuples in ascending index order.
        """
        temp_plateau = []
        for i in range(self.amount_values):
            transmission_value, measured_value = self.trans_measured_values[i]
            current_value = measured_value / transmission_value
            if i not in plateau_indices:
                if self._deviation_percent(average, current_value) >= maximum_diff_perc:
                    continue
                plateau_indices.append(i)
                plateau_indices.sort()
            temp_plateau.append((i, current_value, measured_value))
        return temp_plateau

    def calc_linear_regression(self):
        """Fit measured value over transmission value for the plateau entries.

        Sets ``slope`` and ``intercept``. The x values are looked up by the
        index stored in each plateau entry, so x and y always have the same
        length, also if the plateau should contain gaps.
        """
        transmission_values = [self.trans_measured_values[i][0] for i, _, __ in self.best_plateau]
        best_lin_values = [measured_value for _, __, measured_value in self.best_plateau]
        regression = stats.linregress(x=transmission_values, y=best_lin_values)
        self.slope = regression.slope
        self.intercept = regression.intercept
        print("Slope:", self.slope, ", Intercept:", self.intercept)

    def calc_linear_values(self):
        """Evaluate the fitted line at every transmission value (``linear_values``)."""
        # Column D
        self.linear_values = [pair[0] * self.slope + self.intercept for pair in self.trans_measured_values]

    def compare_plateau_measured_values(self):
        """Compute linear value / measured value for every index (``compare_values``).

        A measured value of zero yields 0. The list is rebuilt on every call,
        so calling the method repeatedly does not accumulate entries.
        """
        # Column E
        self.compare_values = [
            linear_value / pair[1] if pair[1] != 0.0 else 0
            for linear_value, pair in zip(self.linear_values, self.trans_measured_values)
        ]

    def create_correction_factors(self):
        """Return (measured value, correction factor) pairs.

        The list starts with ``(0.0, 1.0)`` and the last plateau value with
        factor 1.0, followed by one pair for every index after the plateau.
        """
        last_index, _, last_measured_value = self.best_plateau[-1]
        correction_factor_pair_list = [(0.0, 1.0), (last_measured_value, 1.0)]
        # Every value after the last plateau value needs to be corrected
        for i in range(last_index + 1, self.amount_values):
            correction_factor_pair_list.append((self.trans_measured_values[i][1], self.compare_values[i]))
        return correction_factor_pair_list

    @staticmethod
    def create_correction_values_str(correction_factor_pair_list):
        """Format all pair members with six decimals, joined by semicolons."""
        return ";".join(
            f"{value:.6f}"
            for p1, p2 in correction_factor_pair_list
            for value in (p1, p2)
        )

    def to_csv(self, path):
        """Write columns A-F for every index to ``path``, followed by the line formula.

        Column C holds the measured value of the plateau entry with that index
        and stays empty for indices outside the plateau. Every cell is
        followed by a comma.
        """
        # First entry per index wins.
        plateau_lookup = {}
        for plateau_i, _, lin_value in self.best_plateau:
            plateau_lookup.setdefault(plateau_i, lin_value)

        with open(path, "w") as file:
            for i in range(self.amount_values):
                plateau_entry = plateau_lookup.get(i)
                row = (
                    self.trans_measured_values[i][0],              # A
                    self.trans_measured_values[i][1],              # B
                    "" if plateau_entry is None else plateau_entry,  # C (best_plateau)
                    self.linear_values[i],                         # D
                    self.compare_values[i],                        # E
                    self.plateau_values[i],                        # F
                )
                file.write("".join(str(value) + "," for value in row))
                file.write("\n")
            file.write(f"f(x)={self.slope} x + {self.intercept}")


def write_corrected_values_to_file(pmt_serial_no: str, values):
    """Write the correction string to a dated ``.pmt`` file.

    The file is named ``linearization_correction_<serial>_<YYYY-MM-DD>.pmt``
    and is created in the directory configured as
    ``hal_enum.Application.AdjFilePath``; an empty setting selects the
    current directory.

    Args:
        pmt_serial_no: Serial number that becomes part of the file name.
        values: Text that is written to the file.
    """
    # Only keep year-month-day (e.g. 2023-02-23)
    day_stamp = get_iso_date()[:10]
    configured_dir = hal.get_config(hal_enum.Application.AdjFilePath)
    # An empty path is used for simulation.
    target_dir = "." if configured_dir == "" else configured_dir
    file_name = "".join(("linearization_correction_", pmt_serial_no, "_", day_stamp, ".pmt"))
    with open("/".join((target_dir, file_name)), "w") as out_file:
        out_file.write(values)
